import warnings
import torch
from torch import nn
from torch.nn import functional as F
from loss_functions.focal_loss import FocalLoss
from loss_functions.focal_loss_ada_gamma import FocalLossAdaptive
from loss_functions.mmce import MMCE, MMCE_weighted
from loss_functions.soft_avuc import SoftAvULoss
from loss_functions.soft_ece import SBEceLoss
from loss_functions.auc_loss_bw import AUCLossBw


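# Select and compute the training loss. Most branches return batch-summed losses
# rather than means; soft_avuc, soft_ece and auc_secondary_bw_avg instead combine
# a mean-reduced primary loss without rescaling by the batch size.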
def compute_loss(settings, logits, targets):
    """Select and compute the training loss named by ``settings.loss_type``.

    Args:
        settings: Configuration object. ``loss_type`` is always read; depending
            on the branch, ``primary_loss_type``, ``gamma_FL``, ``device``,
            ``lamda``, ``temp_savuc``, ``k_savuc`` and ``temp_soft_ece`` are
            also read.
        logits: Raw model outputs. The AUC branches apply softmax over
            ``dim=1``, so presumably shaped (batch, num_classes) — confirm.
        targets: Class-index labels accepted by ``CrossEntropyLoss``.
            ``len(targets)`` is used as the batch size when rescaling
            batch-mean terms to batch sums.

    Returns:
        * Single-loss branches ("cross-entropy", "focal-loss",
          "focal-loss-ada", "auc_primary", "auc_primary_bw"): one loss tensor.
        * Combined branches: a 3-tuple
          ``(total, primary, lamda * secondary)`` where ``total`` is the sum
          of the other two.

          - "mmce" rescales its mean-reduced terms by ``len(targets)``.
          - "auc_secondary", "auc_secondary_bw", "auc_secondary_bw_knn" and
            "auc_secondary_bw_knn_logits" use a sum-reduced primary loss and
            scale the secondary term by ``len(targets)``.
          - "soft_avuc", "soft_ece" and "auc_secondary_bw_avg" use a
            mean-reduced primary loss with no rescaling.
        * ``None`` for an unknown ``loss_type``, after emitting a warning.

    Raises:
        ValueError: if a combined branch is selected and
            ``settings.primary_loss_type`` is neither "cross-entropy" nor
            "focal-loss".
    """
    loss_type = settings.loss_type

    # ---- Baselines: single loss tensor ----
    if loss_type == "cross-entropy":
        return torch.nn.CrossEntropyLoss(reduction="sum")(logits, targets)

    if loss_type == "focal-loss":
        return FocalLoss(gamma=settings.gamma_FL)(logits, targets)

    if loss_type == "focal-loss-ada":
        criterion = FocalLossAdaptive(device=settings.device, gamma=settings.gamma_FL)
        return criterion(logits, targets)

    # ---- Baseline calibration losses combined with a primary loss ----
    # In every combined branch the primary term is computed first, so an
    # unsupported primary_loss_type fails fast with a clear ValueError.
    if loss_type == "mmce":
        criterion_primary = _primary_loss(settings, logits, targets, reduce_sum=False)
        criterion_secondary = MMCE_weighted(device=settings.device)(logits, targets)
        batch_size = len(targets)
        # Both terms are batch means here; multiply by the batch size so the
        # returned values are batch sums like the other summed branches.
        return (
            (torch.mul(criterion_secondary, settings.lamda) + criterion_primary)
            * batch_size,
            criterion_primary * batch_size,
            criterion_secondary * batch_size * settings.lamda,
        )

    if loss_type == "soft_avuc":
        criterion_primary = _primary_loss(settings, logits, targets, reduce_sum=False)
        criterion_secondary = SoftAvULoss(temp=settings.temp_savuc, k=settings.k_savuc)(
            logits=logits,
            labels=targets,
        )
        return _combine(criterion_primary, criterion_secondary, settings.lamda)

    if loss_type == "soft_ece":
        criterion_primary = _primary_loss(settings, logits, targets, reduce_sum=False)
        criterion_secondary = SBEceLoss(temp=settings.temp_soft_ece)(
            logits=logits,
            labels=targets,
        )
        return _combine(criterion_primary, criterion_secondary, settings.lamda)

    # ---- Proposed AUC-based losses ----
    if loss_type == "auc_primary":
        # NOTE(review): AUCLoss is not imported in this module's visible header;
        # this branch raises NameError unless it is defined elsewhere — confirm.
        return AUCLoss()(logits, targets) * len(targets)

    if loss_type == "auc_primary_bw":
        confidences = F.softmax(logits, dim=1)
        loss = AUCLossBw.apply(confidences, targets, settings.lamda)
        return loss * len(targets)

    if loss_type == "auc_secondary_bw":
        criterion_primary = _primary_loss(settings, logits, targets, reduce_sum=True)
        confidences = F.softmax(logits, dim=1)
        criterion_secondary = AUCLossBw.apply(confidences, targets, settings.lamda)
        # Scale the secondary term by the batch size to match the summed primary.
        return _combine(
            criterion_primary,
            torch.mul(criterion_secondary, len(targets)),
            settings.lamda,
        )

    if loss_type == "auc_secondary_bw_avg":
        criterion_primary = _primary_loss(settings, logits, targets, reduce_sum=False)
        confidences = F.softmax(logits, dim=1)
        criterion_secondary = AUCLossBw.apply(confidences, targets, settings.lamda)
        return _combine(criterion_primary, criterion_secondary, settings.lamda)

    if loss_type == "auc_secondary":
        criterion_primary = _primary_loss(settings, logits, targets, reduce_sum=True)
        # NOTE(review): AUCLoss is not imported in this module's visible header — confirm.
        criterion_secondary = AUCLoss()(logits, targets)
        return _combine(
            criterion_primary,
            torch.mul(criterion_secondary, len(targets)),
            settings.lamda,
        )

    if loss_type == "auc_secondary_bw_knn":
        criterion_primary = _primary_loss(settings, logits, targets, reduce_sum=True)
        confidences = F.softmax(logits, dim=1)
        # NOTE(review): AUCLossBw_knn is not imported in this module's visible header — confirm.
        criterion_secondary = AUCLossBw_knn.apply(confidences, targets, settings.lamda)
        return _combine(
            criterion_primary,
            torch.mul(criterion_secondary, len(targets)),
            settings.lamda,
        )

    if loss_type == "auc_secondary_bw_knn_logits":
        criterion_primary = _primary_loss(settings, logits, targets, reduce_sum=True)
        # NOTE(review): AUCLossBw_knn_logits is not imported in this module's
        # visible header — confirm.
        criterion_secondary = AUCLossBw_knn_logits.apply(logits, targets, settings.lamda)
        return _combine(
            criterion_primary,
            torch.mul(criterion_secondary, len(targets)),
            settings.lamda,
        )

    warnings.warn("Loss function is not listed.")
    return None


def _primary_loss(settings, logits, targets, *, reduce_sum):
    """Compute the primary term used by the combined-loss branches.

    ``settings.primary_loss_type`` selects the loss.

    * With ``reduce_sum=True``: cross-entropy is summed over the batch, and
      focal loss uses its default reduction (presumably a sum — confirm).
    * With ``reduce_sum=False``: cross-entropy is averaged, and focal loss is
      built with ``size_average=True``.

    Raises:
        ValueError: if ``primary_loss_type`` is not "cross-entropy" or
            "focal-loss".
    """
    primary_type = settings.primary_loss_type
    if primary_type == "cross-entropy":
        reduction = "sum" if reduce_sum else "mean"
        return torch.nn.CrossEntropyLoss(reduction=reduction)(logits, targets)
    if primary_type == "focal-loss":
        if reduce_sum:
            return FocalLoss(gamma=settings.gamma_FL)(logits, targets)
        return FocalLoss(gamma=settings.gamma_FL, size_average=True)(logits, targets)
    raise ValueError(
        f"Unsupported primary_loss_type {primary_type!r}; "
        "expected 'cross-entropy' or 'focal-loss'."
    )


def _combine(criterion_primary, criterion_secondary, lamda):
    """Weight the secondary term by ``lamda`` and add it to the primary term.

    Returns:
        ``(lamda * secondary + primary, primary, lamda * secondary)``.
    """
    weighted_secondary = torch.mul(criterion_secondary, lamda)
    return weighted_secondary + criterion_primary, criterion_primary, weighted_secondary
